{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai.gen_doc.nbdoc import *\n",
    "from utils import *\n",
    "from fastai2.collab import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building and training a model in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The basic training loop in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.ML_100k)\n",
    "ratings = pd.read_csv(path/'u.data', delimiter='\\t', header=None,\n",
    "                      names=['user','movie','rating','timestamp'])\n",
    "\n",
    "dls = CollabDataLoaders.from_df(ratings, item_name='movie', bs=64).cpu()\n",
    "n_users,n_movies  = len(dls.classes['user']),len(dls.classes['movie'])\n",
    "n_factors = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_embedding  = (torch.randn(n_users,  n_factors) * 0.1).requires_grad_()\n",
    "movie_embedding = (torch.randn(n_movies, n_factors) * 0.1).requires_grad_()\n",
    "user_bias  = torch.zeros(n_users).requires_grad_()\n",
    "movie_bias = torch.zeros(n_movies).requires_grad_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 2]), torch.Size([64, 1]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = dls.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64]), torch.Size([64]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "users,movies = x.T\n",
    "users.shape,movies.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "usr_emb = user_embedding[users]\n",
    "usr_b   = user_bias[users]\n",
    "mov_emb = movie_embedding[movies]\n",
    "mov_b   = movie_bias[movies]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 5])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "usr_emb.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation = (usr_emb * mov_emb).sum(dim=1) + usr_b + mov_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = 5 * torch.sigmoid(activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x):\n",
    "    users,movies = x.T\n",
    "    usr_emb,usr_b = user_embedding[users],user_bias[users]\n",
    "    mov_emb,mov_b = movie_embedding[movies],movie_bias[movies]\n",
    "    activation = (usr_emb * mov_emb).sum(dim=1) + usr_b + mov_b\n",
    "    return 5 * torch.sigmoid(activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mse_loss(output, target): return (output-target).pow(2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = mse_loss(output,y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "user_embedding.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-2\n",
    "user_embedding.data -= lr * user_embedding.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_one_epoch(model, dls, lr):\n",
    "    for x,y in dls.train:\n",
    "        output = model(x)\n",
    "        loss = mse_loss(output,y)\n",
    "        loss.backward()\n",
    "        for param in [user_embedding, movie_embedding, user_bias, movie_bias]:\n",
    "            param.data -= lr * param.grad\n",
    "            param.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate(model, dls):\n",
    "    n_elem,loss = 0,0.\n",
    "    with torch.no_grad():\n",
    "        for x,y in dls.valid:\n",
    "            output = model(x)\n",
    "            loss += mse_loss(output,y) * y.size(0)\n",
    "            n_elem += y.size(0)\n",
    "    return loss/n_elem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.3391)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "validate(model, dls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(model, dls, n_epoch, lr):\n",
    "    for i in progress_bar(range(n_epoch)):\n",
    "        fit_one_epoch(model, dls, lr)\n",
    "        val_loss = validate(model, dls)\n",
    "        print(f'Epoch {i+1}, validation loss: {val_loss:.6f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='5' class='' max='5', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [5/5 00:35<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, validation loss: 1.759184\n",
      "Epoch 2, validation loss: 1.535529\n",
      "Epoch 3, validation loss: 1.434513\n",
      "Epoch 4, validation loss: 1.382121\n",
      "Epoch 5, validation loss: 1.351903\n"
     ]
    }
   ],
   "source": [
    "fit(model, dls, 5, 3e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Refactoring the model with nn.Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DotProductBias(Module):\n",
    "    def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
    "        self.user_factors = Embedding(n_users, n_factors)\n",
    "        self.user_bias = Embedding(n_users, 1)\n",
    "        self.movie_factors = Embedding(n_movies, n_factors)\n",
    "        self.movie_bias = Embedding(n_movies, 1)\n",
    "        self.y_range = y_range\n",
    "        \n",
    "    def forward(self, x):\n",
    "        users = self.user_factors(x[:,0])\n",
    "        movies = self.movie_factors(x[:,1])\n",
    "        res = (users * movies).sum(dim=1, keepdim=True) \n",
    "        res += self.user_bias(x[:,0]) + self.movie_bias(x[:,1])\n",
    "        return sigmoid_range(res, *self.y_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DotProductBias(n_users, n_movies, n_factors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = F.mse_loss(out, y.float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = nn.MSELoss()\n",
    "loss = loss_func(out, y.float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = MSELossFlat()\n",
    "loss = loss_func(out, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_one_epoch(model, dls, lr):\n",
    "    for x,y in dls.train:\n",
    "        output = model(x)\n",
    "        loss = loss_func(output,y)\n",
    "        loss.backward()\n",
    "        for param in model.parameters(): param.data -= lr * param.grad\n",
    "        model.zero_grad()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Going on the GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = out.cuda()\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = out.to(default_device())\n",
    "model = model.to(default_device())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=5)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "default_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = CollabDataLoaders.from_df(ratings, item_name='movie', bs=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_one_epoch(model, dls, 1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using torch.nn.optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = optim.SGD(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_one_epoch(model, opt, dls):\n",
    "    for x,y in dls.train:\n",
    "        output = model(x)\n",
    "        loss = loss_func(output,y)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = MSELossFlat(reduction='sum')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate(model, dls):\n",
    "    n_elem,loss = 0,0.\n",
    "    with torch.no_grad():\n",
    "        for x,y in dls.valid:\n",
    "            output = model(x)\n",
    "            loss += loss_func(output,y)\n",
    "            n_elem += y.size(0)\n",
    "    return loss/n_elem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(model, dls, n_epoch, lr, opt_func=optim.SGD):\n",
    "    opt = opt_func(model.parameters(), lr=lr)\n",
    "    for i in progress_bar(range(n_epoch)):\n",
    "        fit_one_epoch(model, opt, dls)\n",
    "        val_loss = validate(model, dls)\n",
    "        print(f'Epoch {i+1}, validation loss: {val_loss:.6f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DotProductBias(n_users, n_movies, n_factors).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='4' class='' max='4', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [4/4 00:36<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, validation loss: 1.034514\n",
      "Epoch 2, validation loss: 0.909997\n",
      "Epoch 3, validation loss: 0.883697\n",
      "Epoch 4, validation loss: 0.869609\n"
     ]
    }
   ],
   "source": [
    "fit(model, dls, 4, 1e-3, opt_func=optim.Adam)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a neural net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVoAAAFkCAYAAABhIfOrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAf5ElEQVR4nO3dd7QV1fXA8X0EfiAqoRclCoigohjBRhQbBuwiooiUJyIWRIJoRMVCshALYIEINnRBxBIEoxGMggYEG4ZikLLQSAsBQVCk1/n9kXg8Z3z3et97s+/cO/f7WYu19vnteXPPL8PdztucOWOCIBAAgJ794p4AACQdhRYAlFFoAUAZhRYAlFFoAUAZhRYAlFFoAUBZIgutMeYFY8waY8z3xpilxphr454Tys4YsyX0Z68xZmTc80LZJf07a5L4wIIxppmIfBkEwU5jzJEiMl1ELgiCYE68M0NUjDEHiMjXInJ+EATvxz0flE3Sv7OJvKMNgmBhEAQ7fxj+78/hMU4J0esoIutEZGbcE0HZJf07m8hCKyJijBlljNkmIktEZI2ITIl5SohWkYiMC5L4K1mBSvJ3NpGtgx8YY8qJSCsROVNEHgqCYHe8M0IUjDGHisgyEWkcBMGyuOeD6CT1O5vYO1oRkSAI9gZBMEtE6ovIjXHPB5HpLiKzKLLJk9TvbKILraO8JKjfA+kuImPjngRUJeo7m7hCa4ypbYy50hhzoDGmnDGmnYh0FpH34p4bys4Y82sROUREJsQ9F0SjEL6zievRGmNqicirInKc/Pc/JCtEZEQQBM/EOjFEwhjzlIhUDoKgW9xzQTQK4TubuEILALkmca0DAMg1FFoAUEahBQBlFFoAUEahBQBl5dMljTEsSYhREARG47xc13hpXVcRrm3cUl1b7mgBQBmFFgCUUWgBQBmFFgCUUWgBQBmFFgCUUWgBQBmFFgCUUWgBQBmFFgCUUWgBQBmFFgCUUWgBQFna3bsAIJuqVatm4549e3q5oUOH2nj79u1e7pJLLrHx1KlTlWZXetzRAoAyCi0AKEv7uvFc2UT4yiuvTDlu166dl9uwYYONly9fbuNrrrnGO27p0qURzlAHG38nExt/+0466SQbP/zwwzZu3bp1xufYtGmTjRcvXpzyuFdffdXGjz76aMbnzxQbfwNATCi0AKCMQgsAynK2RzthwgQbX3rppV7un//8p4137NiR8hyNGze28f777+/lbr/9dhuPHj261PPURI82mQq9R3v11Vd746eeesrG5cuXbsXp2rVrbbxt2zYv16hRIxsvWrTIxscee2ypPisderQAEBMKLQAoy9knwy677DIbX3755V7ujTfesPHu3btTnqNOnTo2fuWVV7ycu7Rj8+bNXu6FF14o2WSRkjH+b1Lukz81a9b0cu4SvC5duni5+vXr2/jjjz+28cyZM73j9u3bl3Iubjvqs88+83J79uxJ+XOI1qpVq7xxuXLlij3u888/98aPPfaYjVeuXOnlFixYYGN3qZeIyJIlS0o1zyhxRwsAyii0AKAsZ1cddOrUycYTJ070cqX5Ne+www7zxs8884yN9+7d6+XOO++8Ep9fQz6tOnCf0HPbPvvt5/+3vEePHlF/dKm4TyCJiNx77702TteOikKhrzoIC7eQfhBePRAep3LGGWd447feesvG7oqlU045JdMpZoxVBwAQEwotACij0AKAspzt0WorKiqy8bBhw7xcrVq1sj2dYsXdo23RooU3vvPOO1Me26ZNGxv/4he/yGge4af6vvnmGxuvWbMm5c998sknxcYiIn379rVx5cqVvVyzZs1SntN9UnD48OEpj4sCPdroXXTRRTYeMmSIlzv66KNt7D5l6i4TjQo9WgCICYUWAJTl7JNhiN8NN9zgjTt06FDic0ybNi3leMaMGV7O3Rgk/ORPpl588UUbH3jggV7u3XfftfEJJ5zg
5Y444ohSfR7i8bvf/c4b33fffTYObyB166232vhvf/ub7sRS4I4WAJRRaAFAGYUWAJQVTI+2ZcuW3vjBBx+0sburE340cuRIb/zLX/7Sxm3btvVybu/rgQcesPGHH37oHZdud62oHXLIId74yCOPzNpno+wqVarkjX//+9/buE+fPimPnTx5spf7y1/+YuNdu3ZFOcWMcUcLAMootACgLNGtgwMOOMDGjzzyiJerWrWqjcePH5+1OeUTdzNlEZFLLrnExu6m6iIiX3/9tY3j+vUsLLzMJ7zcyzV//nzt6SAFd3P4jh07FhsXN3a57QJ3OZeIyPLly8s4w7LjjhYAlFFoAUAZhRYAlOV9jza8U1SvXr1s3LNnTxs3adLEO+6Pf/yjjT/44AOl2SWL23sNv2AvF6V66V9xpk6dqjgTuMK74z3xxBM2dt/OEfbdd9/Z+I477vBy7htTchF3tACgjEILAMrysnVw6qmn2vivf/2rl3NbCe6ykfAG5wcddJDS7JArwst8kD3hTdYvvvhiG7vtPZGfvjg1Fff7HN6UvmHDhjZetmxZxvPMFu5oAUAZhRYAlOXFO8PC/8LovrsqXQsgXevAFf4V89FHHy3pFFXE/c6wfFSjRg0bz54928s1aNDAxv/4xz+83DnnnGPjzZs360zufwrhnWGvvPKKN073VJdrxYoVNt62bZuXc7/rNWvW9HKnnXaajefNm5fxPKPGO8MAICYUWgBQRqEFAGV5sbwr3EN1d2GaOHGilxs0aJCN16xZY+PwiwYHDhxo4/BSEeSv7t2729jtyYaFNyTX7ssWmnfeeccbuz3acO6hhx6ysdtf3bRpk3ecu3vcpEmTvJy7EX2cPdpUuKMFAGUUWgBQlhetg/PPP98b79mzx8aLFy/2cjt27Cj2HEOGDPHGFSpUsHHfvn29nLth+NatW0s2WcTq8MMPz+i4F198UXkmhe25557zxq+++qqNt2zZ4uX27t1b7DnKl/fLU+/evVN+3sqVK0s6xazijhYAlFFoAUAZhRYAlOVFj/bTTz+N/JyjRo2y8fXXX+/lBg8ebONbbrkl8s8Gki78yHt4qVYq9evXt/GYMWO8nPuYdPjfYpYsWVLSKWYVd7QAoIxCCwDK8qJ1oGH9+vU2dncMEvF3gAJQdu5yyqpVq3q5sWPH2vjkk09OeZwr/HQZrQMAKHAUWgBQRqEFAGUF26OtW7eujcOPbX755ZfZng5KqWLFit74pJNOimkmhSe8O1rjxo1t3KlTJy/nvjzxrLPOyuj87r+jiIjcdNNNNp4yZYqX2759e0bnjAt3tACgjEILAMoKtnVw22232Ti8nCv84j7kLnfZkIhIy5YtUx67dOlSG4eX9BWC3/zmNzZu3ry5l+vZs6eNw7tmLVy40MbNmjWzcbVq1bzjqlevXqp5rV692sZdunSx8apVq7zjli9fXqrz5wLuaAFAGYUWAJQVTOvAfS+RiEifPn1sPHPmTC/37LPPZmVOKLs6depkfKz7Lql169ZpTCenXXXVVTbu0KGDl3PfwxeW6WbqrrVr13rj2bNn2zj8vq/x48fbeN++fSX+rHzAHS0AKKPQAoAyCi0AKEtcj7ZKlSo27ty5s4379evnHbdx40Yb9+/f38vxQsb8cc8996TMffXVV944/BLOQtOjRw8bL1q0yMuFn/LKxK5du7zx888/b+MNGzZ4OXcJVyHijhYAlFFoAUCZCb/bx0sakzoZI7c90LFjRy83ZMgQG9euXdvG4f8/3V+jxo0bF/UUIxEEgdE4b65e19LYuXOnN3afavrss8+8XIsWLbIyp5+jdV1FknVt81Gqa8sdLQAoo9ACgDIKLQAoy8sebaGgR/vz6NH6knRt8xE9WgCICYUWAJQl7skwFJbHH3/cG9966602/vjjj7M9HaBY3NECgDIKLQAoo9ACgDKWd+UwlnclE8u7kovlXQAQEwotAChL2zoAAJQdd7QAoIxCCwDKKLQAoCxx
hdYYsyX0Z68xZmTc80LZGWNeMMasMcZ8b4xZaoy5Nu45IRpJv7aJ/scwY8wBIvK1iJwfBMH7cc8HZWOMaSYiXwZBsNMYc6SITBeRC4IgmBPvzFBWSb+2ibujDekoIutEZGbcE0HZBUGwMAiCHzagDf735/AYp4SIJP3aJr3QFonIuCDJt+0FxhgzyhizTUSWiMgaEZkS85QQkSRf28S2Dowxh4rIMhFpHATBsrjng+gYY8qJSCsROVNEHgqCYHe8M0JUknptk3xH211EZlFkkycIgr1BEMwSkfoicmPc80F0knptk15ox8Y9CagqLwnq48GTqGubyEJrjPm1iBwiIhPinguiYYypbYy50hhzoDGmnDGmnYh0FpH34p4byqYQrm0ie7TGmKdEpHIQBN3inguiYYypJSKvishx8t8bhBUiMiIIgmdinRjKrBCubSILLQDkkkS2DgAgl1BoAUAZhRYAlFFoAUAZhRYAlJVPl+SNmvHiLbjJxFtwk4u34AJATCi0AKCMQgsAyii0AKCMQgsAyii0AKCMQgsAyii0AKCMQgsAyii0AKCMQgsAyii0AKCMQgsAytLu3gWUVceOHb3xmDFjbLxixQovt3nzZhsvWbLEy02aNMnGs2bNsvGmTZsimSeyp1atWt64ZcuWNr7zzjtTHte0aVMbb9iwwcu5fz+++eYbL/fMMz++4zH8dy5buKMFAGUUWgBQlvZ142wiHK8kbPz92GOPeeO+ffuW+Zxr16618ZQpU7xc//79bfz999+X+bM0FMLG3x06dPDGbkugZs2aXu7QQw+1sVuPjPH/Zyptbv369TauW7fuz869LNj4GwBiQqEFAGUUWgBQVrA92vr169u4Z8+eXm7AgAE2rlSpkpc744wzbDxz5kyl2f1XEnq0++3n/7f8ggsusPGHH37o5fbs2WPjs88+28sdeeSRNr7llltsHO73uUu/zjvvPC+3devWTKetKqk92okTJ9q4ffv2Xs7tmy5evNjLrVy50savvfZayvO7S7jCfVi3B9y1a1cvV6NGDRvPnTvXxieeeGLKzyoterQAEBMKLQAoy/vWQcWKFb2x+yun2wKoU6eOd1y1atVsHH4CJZ233nrLxhdeeGHGP1caSWgdaGjQoIGNw+0Hd/nO9ddf7+XcJ4TilJTWQXgJ14QJE2wcritFRUU2DrcHtm3bFum8+vXr542HDRtm43nz5tmY1gEAJAiFFgCU5WXr4IgjjrDx4MGDvVx4E5MfpHuSpCTcX1Vbt25dqnNkitbBzwuvGEm3gUjDhg2zMqefk5TWQdinn35q43B7YMiQIaqfffrpp9t4+vTpXs79rl9++eU2dlcxRIXWAQDEhEILAMootACgLC96tPXq1fPG06ZNs7H7xFA69Gh/lCvXNQrh5X0LFy60caNGjbxc+Cm1uCS1R6vtqKOOsvFdd93l5U477TQbu7uBifhPoh1zzDFKs/sverQAEBMKLQAoy9l3hjVp0sTGM2bM8HK1a9fO9nSsyZMnx/bZ+KmDDz7YG4fbBcgv7lOaTz75pJdzN6pJ1wp0N44R+enmQnHgjhYAlFFoAUAZhRYAlMXao3U35O3Tp4+Xu++++2wcxdKs8C5P7mN64aUi6bgvekP2VKhQwcbu9b///vtT/oz7SCj0XXfddRkdd+mll3pjd/mVu2SyRYsW3nHudX/66ae93JIlS2w8fvx4L/fNN99kNC9N3NECgDIKLQAoy3rroG3btjZ+9tlnbXzIIYd4x6VrD6TLue+FeuCBB2w8atQo77i+fftmdL6SfDai065dO2/8yCOP2Hj37t02bt68uXfcjh07bNyjRw+l2UFE5E9/+pM37tKli43D3xO3/RfOudfazYVbht27d7dxuD2Q67ijBQBlFFoAUEahBQBl6j3a448/3hun68tGoVu3bjZ+/fXXU86jV69ekX82otOqVStv7O7clM7NN99s40WLFkU6J/jCbyjYsmVLRj8XXt7lPnab
7t9Axo4da+OuXbt6OfcNDjNnzsxoHtnEHS0AKKPQAoAy9dbBTTfd5I2jaBfs3bvXxl988YWX27hxY7E/E376q7Tz2LRpU0bHHXbYYTauUqWKl1uwYEGpPruQDBo0yBu7v5befffdNg7/b7ty5UrVeeFH4Rcwhsep3HjjjSlzboto3LhxXs7d5D+8/M9dNuq+gFFE5yWMJcUdLQAoo9ACgDL1d4Y1btzYG7vv8+ncuXNG51i3bp03dlcTzJ49O6Nz7Nu3zxtn+oTX0qVLvXGm//odBd4ZVrwbbrjBxiNGjPByn3/+uY3DKxd27typO7EM8c6w0nFbBxMnTvRyTZs2tfH27du9nNvScJ8u08A7wwAgJhRaAFBGoQUAZeo92jjVqVPHxmvXrvVymfZo3ffFi/x0A3FN9Gh/3sKFC72x20Nv0KCBl8uVpV/0aKPn7u7Wr18/L+d+192lXxrLvujRAkBMKLQAoCzWd4Zpu+CCC2wcbhWkax24y8fmzZsX/cQQmfnz53tjt3UQXt6VK60DRM/dVMbd/F9E5M4777Tx8OHDbfz+++97x2m+W4w7WgBQRqEFAGUUWgBQlugebZs2bTI6bs2aNd74nnvusXH4cT7klnS99urVq2dxJoiT2191/41FxN+5z91Vz90OIHyOqHFHCwDKKLQAoCxxrYOrrrrKxpdddlnK49zNw4cOHerlwk8bIbc0atTIxh06dPByxvz4YM6KFSuyNifkjvbt23vjTJ8C1cQdLQAoo9ACgDIKLQAoS1yPtnnz5jauUKFCyuO+/fZbGz/++OOqc0LZ1atXz8ajR4+2caVKlbzj3Jc4zpkzR39iCeQugRIRad26tY3ffvttG69fvz5rcxLxH68+99xzvdwdd9xh49q1a3s59+0q9957r43nzp0b9RRT4o4WAJRRaAFAWd63DmrUqOGNzzrrLBu7S33cWETkgw8+0J0YRESkZs2a3rht27Y2fuedd7zc7t27bRx+qs/d2Nl9oie8dOfaa6+18ddff12KGRemgQMH2rhv375ezv2O9e7d28ZPP/10xue/7rrrUuZ69eqV0TnclzNWrlzZy7l/D8Ivc+3WrZuNw3/nsoU7WgBQRqEFAGV5/86wK664whu/9NJLxR63Z88eb1xUVGTjl19+OfqJRSAJ7wxbvHixN27atGmk57/lllu8cT6sIMmFd4aFVxbMnj3bxrVq1fJybo1wW3Dh2lHa3KpVq4rNhefornIIrxh47bXXbFySlkbUeGcYAMSEQgsAyii0AKAs75d3de7cOaPj3B6USO72ZZNm9erV3ri0PVr3xYru8qM333yzdBODx+2NZvoi0/D/3e2Nuj1TkfSbaqd6aWa6jbnz7UWb3NECgDIKLQAoy/vWQabefffduKdQkM4555y4p4BihDdFr1u3bkwzSU3zHV7Zxh0tACij0AKAMgotACjL+x7tF198kTLn7vg0aNCgLMwGAH6KO1oAUEahBQBleb97V5IlYfcu/FQu7N4FHezeBQAxodACgDIKLQAoo9ACgDIKLQAoo9ACgLK0y7sAAGXHHS0AKKPQAoAyCi0AKEtcoTXGbAn92WuMGRn3vFB2XNtkKoTrmvfbJIYFQXDgD7Ex5gAR+VpEJsQ3I0SFa5tMhXBdE3dHG9JRRNaJyMy4J4LIcW2TKZHXNemFtkhExgWsYUsirm0yJfK6JnYdrTHmUBFZJiKNgyBYFvd8EB2ubTIl+bom+Y62u4jMStoFg4hwbZMqsdc16YV2bNyTgAqubTIl9romsnVgjPm1iEwVkbpBEGyOez6IDtc2mZJ+XZN6R1skIpOSeMHAtU2oRF/XRN7RAkAuSeodLQDkDAotACij0AKAMgotACij0AKAsrS7dxljWJIQoyAIjMZ5ua7x0rquIlzbuKW6ttzRAoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQ
AoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQAoAyCi0AKKPQAoCy8nFPQFPjxo1t/OWXX0Z+/hEjRti4Z8+eXu60006z8bx58yL/7LgddNBB3vjSSy+18ZQpU2z83Xffecft2bMn5TkrVKhg40qVKqU8rmrVqja+4oorvFzHjh1tXL9+fS/3wgsv2Pjuu+/2cnv37k35ecgvNWvW9MaDBg2yce/evW3s/n0QEenevbvanLijBQBlFFoAUJb3rYPwr5gjR460cadOnWw8depU77gbbrjBxuvXr8/48/r162fjLl262Hj79u3ecf/+978zPmc+atWqlTd+7LHHbNy/f/+UP7du3bqUuYMPPtjGRx99tJczxtg4CAIb79q1yztu48aNNl64cKGXGzBggI0XL17s5caNG5dyXvjR+eefb+N69epl/HM1atSw8VFHHeXltm7dauPPPvvMxvv27Ut5vkaNGnnj4447zsannnqql6tSpYqN3b87W7Zs+blpR4Y7WgBQRqEFAGUUWgBQZtyexU+SxqROxshd3vPJJ594ucMPP7zYn3F7fCL+sqN0y6+OOeYYb+wua6pYsaKNi4qKvOPGjx+f8pyZCoLA/PxRJadxXatXr27jAw44wMbhZW/7779/RucL97zffvvtYo/btGmTN160aJGNr7rqKi/nLudxl/mIiDz55JMZzSsKWtdVROfa/t///Z+NP/30UxuHvxtRC39n09WqTM2ZM8fGrVu39nI7d+4s8/lTXVvuaAFAGYUWAJTlxfKuk08+2Ru//vrrNg4/BeK2BNylIuFfQ9zlIGeeeWbKc7i/Non4y8k++ugjGyfx6a+ScJdVubH7VE4uCS/9QmruEroFCxbY2G0RifjfDffpQBF/Wdju3bu9XKqn8sLf2WXLltk4vDTroosuKvYcIiJLly61cfv27W0cRasgU9zRAoAyCi0AKKPQAoCynO3RHnvssTaePHmyl3OXd3377bde7oQTTrDxihUrUp6/QYMGNq5cubKXc3eEuv32273cv/71LxtfeeWVNk76I7f5qG3btt7Y/bviLlNC5rp27Zoyd/nll9t4woQJXs7t54Yfmw73bDNx8cUXe+N0Pdrp06fb+D//+U+JPysK3NECgDIKLQAoy5knw6699lpv/OCDD9rYfepIRGTatGk2HjhwoJcrza+E4U2sx44da+Pwk2buvF566aUSf1ZJ5NOTYbmiTp06NnaXA4mIvPLKKzbu0aNH1uYUlm9PhuUKdxeu1atXezm3/ffVV195OfcJNu0lXTwZBgAxodACgDIKLQAoi3V5l7sUa/DgwV7OXcK1fPlyL+cuvwq//K80wr3Wc88918bDhw9PeyxKZr/9fvxve7od9EvrlFNOsXH48elJkyZF/nnInm7dutk4vCTT7b0+8cQTKXNx4Y4WAJRRaAFAWaytA3dnp/AuXJ9//rmN+/bt6+WiaBdcffXVNj799NO93JtvvmnjYcOGlfmzClmzZs288fPPP2/jXr16ebmVK1eW+PzNmzf3xrfddpuNw0/r1apVy8b333+/l3OfRHSFNwQP70oFPW4bSERk6NChKY91d+pzXxSaK7ijBQBlFFoAUJb11kHHjh1tfPbZZ9vYfdeTiMiMGTNs/P7775f5c5s0aeKN3XbE/Pnzvdwbb7xh4/Xr15f5swvZjh07vPHBBx9s42xvlu62qsJ/pz788EMbu20r91dSZJf7hKaI/46+8Lvi3LqSi7ijBQBlFFoAUEahBQBlWe/RupsvjxkzxsbPPfecd1wU/Tt3yc7f//53L+c+eVavXj0vR182Ou5G6SL+cq+mTZt6uUsuuaTE5+/Tp483dl/o16ZNGy/HZt+5z13+d8QRR3g5d8Pw7t27e7nwbl65hjtaAFBGoQUAZTmz8XdU3HcTDRgwwMb33HOPd5z7K2d4E4pcwcbfxXNbQnPmzPFy7qY/RUVFWZtTSbDxt++yyy6zsbs5e7ly5bzjRo8ebePe
vXvrT6wU2PgbAGJCoQUAZRRaAFAW6+5dGtyX7t111102dh+rFREZP3581uaEaLkv8ixf3v8rPGLEiGxPByXkPoYt4j9q6y7PW7x4sXdcrvZlM8EdLQAoo9ACgLK8bx2Ed/i58MILbfzyyy/b+Oabb/aOi2LzcGSP+xTZTTfdZOMJEyZ4x82dOzdrc0LpPPDAA954//33L/a43/72t9mYTlZwRwsAyii0AKCMQgsAyvKyR9uqVSsbt2jRwsu5jxS7b2lwdw1D7nOX+YiIPPTQQzZetmyZja+55hrvuHSPlCM+3bp1s3GnTp1SHvf444/beOrUqapzyibuaAFAGYUWAJTlxe5d1atX98aTJ0+2sbuTk4jIwIEDbez+GpKPCnn3rtatW3vj9957z8b9+/e38ciRI7M2p6gUwu5d4ZehfvTRRzZ2N90X8Tf5P+WUU2y8Z88epdnpYfcuAIgJhRYAlOVF6+DFF1/0xu6/Wj788MNe7tlnn7Vx+H1V+abQWgeVKlWy8fTp071clSpVbHz00Udna0oqCqF18PHHH3vjE088MeWxDRo0sPGqVau0ppQVtA4AICYUWgBQRqEFAGU5+2RYu3btbBx+ksTt34V378r3vmwhO/74420cfuKvffv22Z4OSsjdkP1Xv/pVyuPCO+mtXr1abU65gjtaAFBGoQUAZTmzvOuKK67wxu67n8LvherVq5eNX3vtNd2JxSjpy7vKlSvnjefMmWPj8NOAhx56aFbmlA1JWd511FFHeWO3pVezZk0v524E1Lx5cy+3bdu26CcXE5Z3AUBMKLQAoIxCCwDKYl3e5fZ4hg4d6uV27dpl41mzZnm58Bj5qW7dut7Y7d0NHjw429NBCY0aNcobh/uyrhtvvNHGSerJZoo7WgBQRqEFAGWxtg7cJV3hd0Q9//zzNr7vvvuyNidkT8OGDb3x7t27bfznP/8529NBCb311lve+PTTT7fxH/7wBy83bdq0rMwpV3FHCwDKKLQAoIxCCwDKcuYRXPxU0h/BLVRJeQQXP8UjuAAQEwotAChL2zoAAJQdd7QAoIxCCwDKKLQAoIxCCwDKKLQAoIxCCwDK/h997Rykqrx3eQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x432 with 9 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from fastai.vision import *\n",
    "path = untar_data(URLs.MNIST_SAMPLE)\n",
    "dls = ImageDataLoaders.from_folder(path)\n",
    "dls.show_batch(rows=3, figsize=(6,6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_one_epoch(model, opt, dls, loss_func):\n",
    "    for x,y in dls.train:\n",
    "        output = model(x)\n",
    "        loss = loss_func(output,y)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "    \n",
    "def validate(model, dls, loss_func):\n",
    "    n_elem,loss,corrects = 0,0.,0.\n",
    "    with torch.no_grad():\n",
    "        for x,y in dls.valid:\n",
    "            output = model(x)\n",
    "            loss += loss_func(output, y, reduction='sum')\n",
    "            corrects += (output.argmax(dim=1) == y).float().sum()\n",
    "            n_elem += y.size(0)\n",
    "    return loss/n_elem,corrects/n_elem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(model, dls, n_epoch, lr, loss_func, opt_func=optim.SGD):\n",
    "    opt = opt_func(model.parameters(), lr=lr)\n",
    "    for i in progress_bar(range(n_epoch)):\n",
    "        fit_one_epoch(model, opt, dls, loss_func)\n",
    "        val_loss, acc = validate(model, dls, loss_func)\n",
    "        print(f'Epoch {i+1}, validation loss: {val_loss:.6f}, accuracy: {acc * 100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FirstNeuralNet(Module):\n",
    "    def __init__(self, n_in, n_hidden, n_out):\n",
    "        self.linear1 = nn.Linear(n_in, n_hidden)\n",
    "        self.linear2 = nn.Linear(n_hidden, n_out)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = x[:,0].view(x.size(0), -1)\n",
    "        activation = F.relu(self.linear1(x))\n",
    "        return self.linear2(activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='2' class='' max='2', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [2/2 00:05<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, validation loss: 0.035692, accuracy: 98.82%\n",
      "Epoch 2, validation loss: 0.020682, accuracy: 99.17%\n"
     ]
    }
   ],
   "source": [
    "model = FirstNeuralNet(28*28, 1000, 2).cuda()\n",
    "fit(model, dls, 2, 1e-3, F.cross_entropy, opt_func=optim.Adam)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cross entropy loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_softmax(x): return (x.exp()/(x.exp().sum(-1,keepdim=True))).log()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = dls.one_batch()\n",
    "pred = model(x)\n",
    "sm_pred = log_softmax(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-12.337810516357422, -3.4572341442108154, -0.00045327682164497674)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sm_pred[0,0].item(), sm_pred[1,0].item(), sm_pred[2,1].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1.2338e+01, -3.4572e+00, -4.5328e-04], device='cuda:5', grad_fn=<IndexBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sm_pred[[0,1,2], [0,0,1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nll(input, target): return -input[range(target.shape[0]), target].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0014, device='cuda:5', grad_fn=<NegBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss on the batch: hand-written nll applied to the log-softmax output\n",
    "loss = nll(sm_pred, y)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_softmax(x):\n",
    "    \"Log-softmax over the last dimension using log(a/b) = log(a) - log(b).\"\n",
    "    log_total = x.exp().sum(-1, keepdim=True).log()\n",
    "    return x - log_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logsumexp(x):\n",
    "    \"Numerically stable log(sum(exp(x))) over the last dimension, for input of any rank.\"\n",
    "    # Subtract the per-row max before exponentiating so exp() cannot overflow; add it back after the log.\n",
    "    # keepdim=True makes the max broadcast against the last axis whatever the rank of `x`\n",
    "    # (indexing with [:,None] was only correct for 2-D input).\n",
    "    m = x.max(-1, keepdim=True)[0]\n",
    "    return m.squeeze(-1) + (x-m).exp().sum(-1).log()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(5.9457, device='cuda:5', grad_fn=<SelectBackward>),\n",
       " tensor(5.9457, device='cuda:5', grad_fn=<SelectBackward>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First row of the hand-written logsumexp next to the built-in tensor method, for comparison\n",
    "logsumexp(pred)[0], pred.logsumexp(-1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_softmax(x):\n",
    "    \"Log-softmax over the last dimension, built on the tensor's own `logsumexp` method.\"\n",
    "    lse = x.logsumexp(-1, keepdim=True)\n",
    "    return x - lse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0014, device='cuda:5', grad_fn=<NllLossBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# PyTorch's built-in log_softmax followed by nll_loss, on the same predictions and targets\n",
    "F.nll_loss(F.log_softmax(pred, -1), y.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0014, device='cuda:5', grad_fn=<NllLossBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# F.cross_entropy takes the raw predictions directly; it displays the same value as the two-step call\n",
    "F.cross_entropy(pred, y.cuda())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dropout and Batchnorm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DropoutLayer(Module):\n",
    "    \"Dropout: zeroes each activation with probability `p` while training, identity otherwise.\"\n",
    "    def __init__(self, p): self.p = p\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Outside training the input must pass through untouched; without this\n",
    "        # early return the method fell off the end and returned None.\n",
    "        if not self.training: return x\n",
    "        # Keep each element with probability 1-p and rescale the survivors by 1/(1-p).\n",
    "        mask = x.new_empty(x.size()).bernoulli_(1-self.p).div_(1-self.p)\n",
    "        return mask * x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batch normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchNormLayer(Module):\n",
    "    \"Batch normalization over dim 0 with a learnable scale (`mults`) and shift (`adds`).\"\n",
    "    def __init__(self, nf, mom=0.1, eps=1e-5):\n",
    "        # NB: pytorch bn mom is opposite of what you'd expect!\n",
    "        self.mom,self.eps = mom,eps\n",
    "        self.mults = nn.Parameter(torch.ones (nf,))\n",
    "        self.adds  = nn.Parameter(torch.zeros(nf,))\n",
    "        # Running statistics are registered as buffers rather than parameters.\n",
    "        self.register_buffer('vars',  torch.ones(1,nf))\n",
    "        self.register_buffer('means', torch.zeros(1,nf))\n",
    "\n",
    "    def update_stats(self, x):\n",
    "        \"Return the batch mean and variance of `x`, blending them into the running buffers.\"\n",
    "        m = x.mean(0, keepdim=True)\n",
    "        v = x.var (0, keepdim=True)\n",
    "        # Only the running-average update is kept out of autograd; `m` and `v` stay differentiable.\n",
    "        with torch.no_grad():\n",
    "            self.means.lerp_(m, self.mom)\n",
    "            self.vars.lerp_ (v, self.mom)\n",
    "        return m,v\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Training normalizes with the current batch statistics, which depend on `x`, so they are\n",
    "        # computed with gradients enabled; evaluation uses the stored running averages.\n",
    "        if self.training: m,v = self.update_stats(x)\n",
    "        else: m,v = self.means,self.vars\n",
    "        x = (x-m) / (v+self.eps).sqrt()\n",
    "        return x*self.mults + self.adds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['mults', 'adds', 'vars', 'means'])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The state dict lists the two parameters and the two registered buffers\n",
    "tst = BatchNormLayer(100)\n",
    "tst.state_dict().keys()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
